import warnings

import torch.nn as nn
class ClasificadorRecurrente(nn.Module):
    """Sequence classifier: a recurrent layer followed by a linear head.

    The recurrent layer (LSTM, GRU or vanilla RNN) is run with
    ``batch_first=True`` and the output of its final time step is mapped to
    ``num_classes`` logits by a single ``nn.Linear``.
    """

    # Dispatch table for the supported recurrent architectures.
    _RNN_TYPES = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}

    def __init__(self, tipo_rnn: str, input_size: int, hidden_size: int,
                 num_classes: int, num_layers: int = 1):
        """Build the recurrent layer and the classification head.

        Args:
            tipo_rnn: ``'LSTM'``, ``'GRU'`` or ``'RNN'`` (case-sensitive).
                Any other value falls back to ``nn.RNN`` (kept for backward
                compatibility) but now emits a ``UserWarning`` so typos such
                as ``'lstm'`` are not silently ignored.
            input_size: number of features per time step.
            hidden_size: size of the recurrent hidden state.
            num_classes: number of output logits.
            num_layers: number of stacked recurrent layers (default 1,
                matching the previous behavior).
        """
        super().__init__()
        rnn_cls = self._RNN_TYPES.get(tipo_rnn)
        if rnn_cls is None:
            warnings.warn(
                f"Unrecognized tipo_rnn={tipo_rnn!r}; falling back to nn.RNN. "
                f"Expected one of {sorted(self._RNN_TYPES)}.",
                UserWarning,
                stacklevel=2,
            )
            rnn_cls = nn.RNN
        self.rnn = rnn_cls(input_size, hidden_size, num_layers=num_layers,
                           batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits computed from the last time step.

        Args:
            x: input sequence, ``(batch, seq_len, input_size)`` (batch_first),
                or unbatched ``(seq_len, input_size)``.

        Returns:
            Logits of shape ``(batch, num_classes)``, or ``(num_classes,)``
            for unbatched input.
        """
        # LSTM returns (output, (h_n, c_n)); GRU/RNN return (output, h_n).
        # Only the per-step output is used, so one unpacking covers all three.
        out, _ = self.rnn(x)
        # Ellipsis indexing selects the final time step for both batched
        # (batch, seq, hidden) and unbatched (seq, hidden) outputs.
        # NOTE(review): if sequences are right-padded, index -1 may be a
        # padding step; consider pack_padded_sequence if that applies.
        last_step = out[..., -1, :]
        return self.fc(last_step)